{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 使用 CNN卷积神经网络 完成图像分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 数据加载与归一\n",
    "\n",
    "1. 数据导入神经网络 （数据加载），对神经网络进行训练\n",
    "    数据集 CIFAR10 含10类6万张图 (学习阶段的小型图像数据集)\n",
    "          ImageNet 含1000个类超100万张图\n",
    "数据归一\n",
    " -  输入数据变成[0,1] 或 [-1，1]之间\n",
    " - 图像数据像素值一般[0，255]\n",
    "\n",
    "#### Pytorch数据加载： torchvision.datasets\n",
    "#### Pytorch数据归一： torchvision.transforms"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Preprocessing pipeline: ToTensor converts each image to a float tensor in [0, 1],\n",
    "# then Normalize maps every channel to [-1, 1] via (x - mean) / std.\n",
    "# Normalize takes two arguments: the per-channel mean and the per-channel\n",
    "# standard deviation (not the variance).\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])\n",
    "\n",
    "# TODO: data augmentation transforms were left unspecified here.\n",
    "\n",
    "# Training set.\n",
    "# download=True fetches CIFAR-10 into ./data on first use and reuses the local\n",
    "# copy on later runs, so this cell also works on a fresh checkout.\n",
    "train_set = torchvision.datasets.CIFAR10(root='./data',train=True, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True,num_workers=2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Test set (train=False selects the held-out split).\n",
    "# download=True reuses ./data when the archive is already present and only\n",
    "# downloads when it is missing, so the cell does not fail on a fresh checkout.\n",
    "test_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=True,transform=transform)\n",
    "# shuffle=False keeps the evaluation order fixed between runs.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=False,num_workers=2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load a custom (private) dataset. Placeholder only: the image directory is still empty.\n",
    "dir_images = ''\n",
    "# private_set = torchvision.datasets.ImageFolder(root=dir_images,transform=transform)\n",
    "# TODO: the loader for private_set was left unfinished:\n",
    "# private_data_loader = torchvision."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "\n",
    "def im_show(img):\n",
    "    \"\"\"Display an image tensor with matplotlib.\n",
    "\n",
    "    Args:\n",
    "        img: torch.Tensor laid out as [C, H, W]; its values are rescaled\n",
    "            with ``img / 2 + 0.5`` before plotting.\n",
    "    \"\"\"\n",
    "    rescaled = img / 2 + 0.5\n",
    "    # matplotlib expects channels last, so reorder [C, H, W] -> [H, W, C].\n",
    "    hwc = np.transpose(rescaled.numpy(), (1, 2, 0))\n",
    "    plt.imshow(hwc)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Create an iterator over train_loader so single batches can be pulled on demand.\n",
    "data_iter = iter(train_loader)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Fetch one batch from the iterator and display it as a single image grid.\n",
    "# The built-in next() is used because the iterator's .next() method is not\n",
    "# available in recent PyTorch releases.\n",
    "images, labels = next(data_iter)\n",
    "im_show(torchvision.utils.make_grid(images))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN classifier for 3x32x32 inputs.\n",
    "\n",
    "    Two unpadded 3x3 convolutions (3 -> 6 -> 16 channels) followed by three\n",
    "    fully connected layers (16*28*28 -> 512 -> 64 -> 10).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Convolution stage: each unpadded 3x3 conv shrinks height and width\n",
    "        # by 2, so a 32x32 input becomes 30x30 and then 28x28.\n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3)\n",
    "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)\n",
    "        # Fully connected stage; the final layer has 10 outputs, one per class.\n",
    "        self.fc1 = nn.Linear(in_features=16 * 28 * 28, out_features=512)\n",
    "        self.fc2 = nn.Linear(in_features=512, out_features=64)\n",
    "        self.fc3 = nn.Linear(in_features=64, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the forward pass and return the output of the last linear layer.\"\"\"\n",
    "        feature_maps = F.relu(self.conv2(F.relu(self.conv1(x))))\n",
    "        # Flatten every sample's feature maps into one vector for the linear layers.\n",
    "        flat = feature_maps.view(-1, 16 * 28 * 28)\n",
    "        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))\n",
    "        return self.fc3(hidden)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Instantiate the network defined above.\n",
    "net = Net()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "# Loss function used for training: cross-entropy.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# Optimizer: SGD with momentum, updating all parameters of net.\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9dbb54d670>\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/edgar/miniconda3/envs/juliang/lib/python3.8/site-packages/torch/utils/data/dataloader.py\", line 1324, in __del__\n",
      "    self._shutdown_workers()\n",
      "  File \"/Users/edgar/miniconda3/envs/juliang/lib/python3.8/site-packages/torch/utils/data/dataloader.py\", line 1297, in _shutdown_workers\n",
      "    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)\n",
      "  File \"/Users/edgar/miniconda3/envs/juliang/lib/python3.8/multiprocessing/process.py\", line 149, in join\n",
      "    res = self._popen.wait(timeout)\n",
      "  File \"/Users/edgar/miniconda3/envs/juliang/lib/python3.8/multiprocessing/popen_fork.py\", line 44, in wait\n",
      "    if not wait([self.sentinel], timeout):\n",
      "  File \"/Users/edgar/miniconda3/envs/juliang/lib/python3.8/multiprocessing/connection.py\", line 931, in wait\n",
      "    ready = selector.select(timeout)\n",
      "  File \"/Users/edgar/miniconda3/envs/juliang/lib/python3.8/selectors.py\", line 415, in select\n",
      "    fd_event_list = self._selector.poll(timeout)\n",
      "KeyboardInterrupt: \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-12-ecad098d87df>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m      7\u001B[0m         \u001B[0moptimizer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mzero_grad\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      8\u001B[0m         \u001B[0mloss\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 9\u001B[0;31m         \u001B[0moptimizer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m     10\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0mi\u001B[0m \u001B[0;34m%\u001B[0m \u001B[0;36m1000\u001B[0m \u001B[0;34m==\u001B[0m \u001B[0;36m0\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     11\u001B[0m             \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34mf'epoch:{epoch}, step: {i}, loss:{loss.item():.3f} '\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/miniconda3/envs/juliang/lib/python3.8/site-packages/torch/optim/optimizer.py\u001B[0m in \u001B[0;36mwrapper\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m     87\u001B[0m                 \u001B[0mprofile_name\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m\"Optimizer.step#{}.step\"\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mobj\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__class__\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__name__\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     88\u001B[0m                 \u001B[0;32mwith\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mautograd\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mprofiler\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrecord_function\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mprofile_name\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 89\u001B[0;31m                     \u001B[0;32mreturn\u001B[0m \u001B[0mfunc\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m     90\u001B[0m             \u001B[0;32mreturn\u001B[0m \u001B[0mwrapper\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     91\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/miniconda3/envs/juliang/lib/python3.8/site-packages/torch/autograd/grad_mode.py\u001B[0m in \u001B[0;36mdecorate_context\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m     25\u001B[0m         \u001B[0;32mdef\u001B[0m \u001B[0mdecorate_context\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     26\u001B[0m             \u001B[0;32mwith\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__class__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 27\u001B[0;31m                 \u001B[0;32mreturn\u001B[0m \u001B[0mfunc\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m     28\u001B[0m         \u001B[0;32mreturn\u001B[0m \u001B[0mcast\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mF\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdecorate_context\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     29\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/miniconda3/envs/juliang/lib/python3.8/site-packages/torch/optim/sgd.py\u001B[0m in \u001B[0;36mstep\u001B[0;34m(self, closure)\u001B[0m\n\u001B[1;32m    108\u001B[0m                         \u001B[0mmomentum_buffer_list\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mappend\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mstate\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'momentum_buffer'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    109\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 110\u001B[0;31m             F.sgd(params_with_grad,\n\u001B[0m\u001B[1;32m    111\u001B[0m                   \u001B[0md_p_list\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    112\u001B[0m                   \u001B[0mmomentum_buffer_list\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/miniconda3/envs/juliang/lib/python3.8/site-packages/torch/optim/_functional.py\u001B[0m in \u001B[0;36msgd\u001B[0;34m(params, d_p_list, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov)\u001B[0m\n\u001B[1;32m    167\u001B[0m                 \u001B[0mmomentum_buffer_list\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mi\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mbuf\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    168\u001B[0m             \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 169\u001B[0;31m                 \u001B[0mbuf\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmul_\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmomentum\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0madd_\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0md_p\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0malpha\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m1\u001B[0m \u001B[0;34m-\u001B[0m \u001B[0mdampening\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    170\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    171\u001B[0m             \u001B[0;32mif\u001B[0m \u001B[0mnesterov\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "# First training run: plain SGD loop that logs the loss every 1000 steps.\n",
    "epochs = 2\n",
    "for epoch in range(epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()  # clear gradients accumulated by the previous step\n",
    "        output = net(images)\n",
    "        loss = criterion(output, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if i % 1000 != 0:\n",
    "            continue\n",
    "        print(f'epoch:{epoch}, step: {i}, loss:{loss.item():.3f} ')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Training run 2: train and periodically evaluate, recording both loss curves.\n",
    "# tqdm is only used for the progress bar; fall back to a plain iterator when\n",
    "# it is not installed so the cell still runs.\n",
    "try:\n",
    "    from tqdm.auto import tqdm\n",
    "except ImportError:\n",
    "    def tqdm(iterable, **kwargs):\n",
    "        return iterable\n",
    "\n",
    "train_loss_hist = []\n",
    "test_loss_hist = []\n",
    "epochs = 20\n",
    "eval_every = 250  # evaluate on the test set once per this many mini-batches\n",
    "for epoch in tqdm(range(epochs)):\n",
    "    # Training phase\n",
    "    net.train()\n",
    "    train_loss = 0.0\n",
    "    for i , data in enumerate(train_loader):\n",
    "        images, labels = data\n",
    "        output = net(images)\n",
    "        loss = criterion(output, labels) # compute the training loss\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss += loss.item()\n",
    "        # Evaluate after every full window of eval_every batches, so the\n",
    "        # recorded training loss is an average over exactly eval_every batches.\n",
    "        if (i + 1) % eval_every == 0:\n",
    "            net.eval()\n",
    "            test_loss_sum = 0.0\n",
    "            test_batches = 0\n",
    "            with torch.no_grad():\n",
    "                for test_data in test_loader:\n",
    "                    test_images, test_labels = test_data\n",
    "                    test_output = net(test_images)\n",
    "                    # Accumulate over ALL test batches instead of keeping only the last one.\n",
    "                    test_loss_sum += criterion(test_output, test_labels).item()\n",
    "                    test_batches += 1\n",
    "            net.train()  # switch back to training mode for the remaining batches\n",
    "\n",
    "            train_loss_hist.append(train_loss / eval_every)\n",
    "            test_loss_hist.append(test_loss_sum / max(test_batches, 1))\n",
    "            train_loss = 0.0"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Plot the recorded training and test loss curves.\n",
    "plt.figure()\n",
    "plt.plot(train_loss_hist)\n",
    "plt.plot(test_loss_hist)\n",
    "# legend() takes the labels as ONE sequence; two separate positional strings\n",
    "# would be interpreted as (handles, labels).\n",
    "plt.legend(['train_loss_hist', 'test_loss_hist'])\n",
    "plt.title('Loss')\n",
    "plt.xlabel('#mini batch *250')\n",
    "plt.ylabel('Loss')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Model evaluation: top-1 accuracy of net on the test set.\n",
    "correct = 0\n",
    "total = 0\n",
    "net.eval()  # evaluation mode (matters for layers such as dropout / batch norm)\n",
    "with torch.no_grad():\n",
    "    for test_data in test_loader:\n",
    "        test_images, test_labels = test_data\n",
    "        test_output = net(test_images)\n",
    "        # The index of the largest output along dim 1 is the predicted class.\n",
    "        _, predicted = torch.max(test_output, 1)\n",
    "        correct += (predicted == test_labels).sum().item()\n",
    "        # Count the samples of this test batch; use test_labels, not the\n",
    "        # training loop's leftover `labels` variable.\n",
    "        total += test_labels.size(0)\n",
    "\n",
    "print('准确率', float(correct/total))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Save the model's state_dict to ./model.pt\n",
    "torch.save(net.state_dict(), './model.pt')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the saved state_dict from ./model.pt into a fresh network instance.\n",
    "net2 = Net()\n",
    "net2.load_state_dict(torch.load('./model.pt'))\n",
    "\n",
    "# Model evaluation: top-1 accuracy of the reloaded network on the test set.\n",
    "correct = 0\n",
    "total = 0\n",
    "net2.eval()  # evaluation mode (matters for layers such as dropout / batch norm)\n",
    "with torch.no_grad():\n",
    "    for test_data in test_loader:\n",
    "        test_images, test_labels = test_data\n",
    "        test_output = net2(test_images)\n",
    "        # The index of the largest output along dim 1 is the predicted class.\n",
    "        _, predicted = torch.max(test_output, 1)\n",
    "        correct += (predicted == test_labels).sum().item()\n",
    "        # Count the samples of this test batch; use test_labels, not the\n",
    "        # training loop's leftover `labels` variable.\n",
    "        total += test_labels.size(0)\n",
    "\n",
    "print('准确率', float(correct/total))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}